{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b518b04cbfe0"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "906e07f6e562"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d201e826ab29"
      },
      "source": [
        "# 자신만의 콜백 작성하기"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "71699af85d2d"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/guide/keras/custom_callback\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org에서 보기</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ko/guide/keras/custom_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colab에서 실행</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ko/guide/keras/custom_callback.ipynb\">     <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">    GitHub에서 소스 보기</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ko/guide/keras/custom_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">노트북 다운로드</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d75eb2e25f36"
      },
      "source": [
        "## 시작하기\n",
        "\n",
        "콜백은 훈련, 평가 또는 추론 중에 Keras 모델의 동작을 사용자 정의할 수 있는 강력한 도구입니다. TensorBoard로 훈련 진행 상황과 결과를 시각화하기 위한 `tf.keras.callbacks.TensorBoard` 또는 훈련 도중 모델을 주기적으로 저장하는 `tf.keras.callbacks.ModelCheckpoint` 등이 여기에 포함됩니다.\n",
        "\n",
        "이 가이드에서는 Keras 콜백이 무엇인지, 무엇을 할 수 있는지, 어떻게 자신만의 콜백을 빌드할 수 있는지 배웁니다. 콜백 애플리케이션의 몇 가지 간단한 데모를 통해 시작할 수 있습니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b3600ee25c8e"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4dadb6688663"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow import keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "42676f705fc8"
      },
      "source": [
        "## Keras 콜백 개요\n",
        "\n",
        "모든 콜백은 `keras.callbacks.Callback` 클래스를 하위 클래스화하며, 훈련, 테스트 및 예측의 다양한 단계에서 호출되는 메서드 세트를 재정의합니다. 콜백은 훈련 중 모델의 내부 상태 및 통계를 볼 때 유용합니다.\n",
        "\n",
        "콜백(키워드 인수 `callbacks`와 같은)의 목록을 다음 모델 메서드에 전달할 수 있습니다.\n",
        "\n",
        "- `keras.Model.fit()`\n",
        "- `keras.Model.evaluate()`\n",
        "- `keras.Model.predict()`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "46945bdf5056"
      },
      "source": [
        "## 콜백 메서드의 개요\n",
        "\n",
        "### 전역 메서드\n",
        "\n",
        "#### `on_(train|test|predict)_begin(self, logs=None)`\n",
        "\n",
        "`fit`/`evaluate`/`predict` 시작 시 호출됩니다.\n",
        "\n",
        "#### `on_(train|test|predict)_end(self, logs=None)`\n",
        "\n",
        "`fit`/`evaluate`/`predict` 종료 시 호출됩니다.\n",
        "\n",
        "### 훈련/테스트/예측을 위한 배치 레벨의 메서드\n",
        "\n",
        "#### `on_(train|test|predict)_batch_begin(self, batch, logs=None)`\n",
        "\n",
        "훈련/테스트/예측 중에 배치를 처리하기 직전에 호출됩니다.\n",
        "\n",
        "#### `on_(train|test|predict)_batch_end(self, batch, logs=None)`\n",
        "\n",
        "훈련/테스트/예측이 끝날 때 호출됩니다. 이 메서드에서 `logs`는 메트릭 결과를 포함하는 dict입니다.\n",
        "\n",
        "### 에포크 레벨 메서드(훈련만 해당)\n",
        "\n",
        "#### `on_epoch_begin(self, epoch, logs=None)`\n",
        "\n",
        "훈련 중 epoch가 시작될 때 호출됩니다.\n",
        "\n",
        "#### `on_epoch_end(self, epoch, logs=None)`\n",
        "\n",
        "훈련 중 epoc가이 끝날 때 호출됩니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "82f2370418a1"
      },
      "source": [
        "## 기본적인 예제\n",
        "\n",
        "구체적인 예를 살펴보겠습니다. 시작하려면 tensorflow를 가져오고 간단한 Sequential Keras 모델을 정의합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7350ea602e50"
      },
      "outputs": [],
      "source": [
        "# Define the Keras model to add callbacks to\n",
        "def get_model():\n",
        "    model = keras.Sequential()\n",
        "    model.add(keras.layers.Dense(1, input_dim=784))\n",
        "    model.compile(\n",
        "        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),\n",
        "        loss=\"mean_squared_error\",\n",
        "        metrics=[\"mean_absolute_error\"],\n",
        "    )\n",
        "    return model\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "044db5f2dc6f"
      },
      "source": [
        "그런 다음 Keras 데이터세트 API에서 훈련 및 테스트용 MNIST 데이터를 로드합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f8826736a184"
      },
      "outputs": [],
      "source": [
        "# Load example MNIST data and pre-process it\n",
        "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
        "x_train = x_train.reshape(-1, 784).astype(\"float32\") / 255.0\n",
        "x_test = x_test.reshape(-1, 784).astype(\"float32\") / 255.0\n",
        "\n",
        "# Limit the data to 1000 samples\n",
        "x_train = x_train[:1000]\n",
        "y_train = y_train[:1000]\n",
        "x_test = x_test[:1000]\n",
        "y_test = y_test[:1000]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b9acd50b2215"
      },
      "source": [
        "이제 다음의 경우 로깅하는 간단한 사용자 정의 콜백을 정의합니다.\n",
        "\n",
        "- `fit`/`evaluate`/`predict`가 시작하고 끝날 때\n",
        "- 각 에포크가 시작하고 끝날 때\n",
        "- 각 훈련 배치가 시작하고 끝날 때\n",
        "- 각 평가(테스트) 배치가 시작하고 끝날 때\n",
        "- 각 추론(예측) 배치가 시작하고 끝날 때"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cc9888d28e79"
      },
      "outputs": [],
      "source": [
        "class CustomCallback(keras.callbacks.Callback):\n",
        "    def on_train_begin(self, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"Starting training; got log keys: {}\".format(keys))\n",
        "\n",
        "    def on_train_end(self, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"Stop training; got log keys: {}\".format(keys))\n",
        "\n",
        "    def on_epoch_begin(self, epoch, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"Start epoch {} of training; got log keys: {}\".format(epoch, keys))\n",
        "\n",
        "    def on_epoch_end(self, epoch, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"End epoch {} of training; got log keys: {}\".format(epoch, keys))\n",
        "\n",
        "    def on_test_begin(self, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"Start testing; got log keys: {}\".format(keys))\n",
        "\n",
        "    def on_test_end(self, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"Stop testing; got log keys: {}\".format(keys))\n",
        "\n",
        "    def on_predict_begin(self, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"Start predicting; got log keys: {}\".format(keys))\n",
        "\n",
        "    def on_predict_end(self, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"Stop predicting; got log keys: {}\".format(keys))\n",
        "\n",
        "    def on_train_batch_begin(self, batch, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"...Training: start of batch {}; got log keys: {}\".format(batch, keys))\n",
        "\n",
        "    def on_train_batch_end(self, batch, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"...Training: end of batch {}; got log keys: {}\".format(batch, keys))\n",
        "\n",
        "    def on_test_batch_begin(self, batch, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"...Evaluating: start of batch {}; got log keys: {}\".format(batch, keys))\n",
        "\n",
        "    def on_test_batch_end(self, batch, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"...Evaluating: end of batch {}; got log keys: {}\".format(batch, keys))\n",
        "\n",
        "    def on_predict_batch_begin(self, batch, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"...Predicting: start of batch {}; got log keys: {}\".format(batch, keys))\n",
        "\n",
        "    def on_predict_batch_end(self, batch, logs=None):\n",
        "        keys = list(logs.keys())\n",
        "        print(\"...Predicting: end of batch {}; got log keys: {}\".format(batch, keys))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8184bd3a76c2"
      },
      "source": [
        "사용해 보겠습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "75f7aa1edac6"
      },
      "outputs": [],
      "source": [
        "model = get_model()\n",
        "model.fit(\n",
        "    x_train,\n",
        "    y_train,\n",
        "    batch_size=128,\n",
        "    epochs=1,\n",
        "    verbose=0,\n",
        "    validation_split=0.5,\n",
        "    callbacks=[CustomCallback()],\n",
        ")\n",
        "\n",
        "res = model.evaluate(\n",
        "    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]\n",
        ")\n",
        "\n",
        "res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "02113b8677eb"
      },
      "source": [
        "### `logs` dict 사용법\n",
        "\n",
        "`logs` dict에는 손실값과 배치 또는 에포크의 끝에 있는 모든 메트릭이 포함됩니다. 이 예제에는 손실 및 평균 절대 오차가 포함됩니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2c5e71af7abe"
      },
      "outputs": [],
      "source": [
        "class LossAndErrorPrintingCallback(keras.callbacks.Callback):\n",
        "    def on_train_batch_end(self, batch, logs=None):\n",
        "        print(\"For batch {}, loss is {:7.2f}.\".format(batch, logs[\"loss\"]))\n",
        "\n",
        "    def on_test_batch_end(self, batch, logs=None):\n",
        "        print(\"For batch {}, loss is {:7.2f}.\".format(batch, logs[\"loss\"]))\n",
        "\n",
        "    def on_epoch_end(self, epoch, logs=None):\n",
        "        print(\n",
        "            \"The average loss for epoch {} is {:7.2f} \"\n",
        "            \"and mean absolute error is {:7.2f}.\".format(\n",
        "                epoch, logs[\"loss\"], logs[\"mean_absolute_error\"]\n",
        "            )\n",
        "        )\n",
        "\n",
        "\n",
        "model = get_model()\n",
        "model.fit(\n",
        "    x_train,\n",
        "    y_train,\n",
        "    batch_size=128,\n",
        "    epochs=2,\n",
        "    verbose=0,\n",
        "    callbacks=[LossAndErrorPrintingCallback()],\n",
        ")\n",
        "\n",
        "res = model.evaluate(\n",
        "    x_test,\n",
        "    y_test,\n",
        "    batch_size=128,\n",
        "    verbose=0,\n",
        "    callbacks=[LossAndErrorPrintingCallback()],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "742d62e5394a"
      },
      "source": [
        "## `self.model` 속성의 사용법\n",
        "\n",
        "메서드 중 하나가 호출될 때 로그 정보를 수신하는 것 외에도 콜백은 현재 훈련/평가/추론 라운드와 연결된 모델(`self.model`)에 액세스할 수 있습니다.\n",
        "\n",
        "콜백에서 `self.model`로 수행할 수 있는 연산은 다음과 같습니다.\n",
        "\n",
        "- 훈련을 즉시 중단하려면 `self.model.stop_training = True`를 설정합니다.\n",
        "- `self.model.optimizer.learning_rate`와 같은 옵티마이저(`self.model.optimizer`로 사용 가능)의 하이퍼파라미터를 변경합니다.\n",
        "- 주기적으로 모델을 저장합니다.\n",
        "- 각 에포크가 끝날 때 몇 가지 테스트 샘플에 `model.predict()`의 출력을 기록하여 훈련 중에 온전성 검사용으로 사용합니다.\n",
        "- 각 에포크가 끝날 때 중간 특성의 시각화를 추출하여 시간이 지남에 따라 모델이 학습하는 내용을 모니터링합니다.\n",
        "- 기타\n",
        "\n",
        "몇 가지 실제 예를 살펴보겠습니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7eb29d3ed752"
      },
      "source": [
        "## Keras 콜백 애플리케이션의 예"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2d1d29d99fa5"
      },
      "source": [
        "### 최소 손실 시 조기 중지\n",
        "\n",
        "이 첫 번째 예는 `self.model.stop_training` (boolean) 속성을 설정하여 최소 손실에 도달했을 때 훈련을 중단하는 `Callback`을 생성하는 방법을 보여줍니다. 선택적으로, 로컬 최소값에 도달한 후 중단하기 전에 기다려야 하는 에포크 수를 지정하는 인수 `patience`을 제공할 수 있습니다.\n",
        "\n",
        "`tf.keras.callbacks.EarlyStopping`은 더 완전한 일반적인 구현을 제공합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5d2acd79cecd"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "class EarlyStoppingAtMinLoss(keras.callbacks.Callback):\n",
        "    \"\"\"Stop training when the loss is at its min, i.e. the loss stops decreasing.\n",
        "\n",
        "  Arguments:\n",
        "      patience: Number of epochs to wait after min has been hit. After this\n",
        "      number of no improvement, training stops.\n",
        "  \"\"\"\n",
        "\n",
        "    def __init__(self, patience=0):\n",
        "        super(EarlyStoppingAtMinLoss, self).__init__()\n",
        "        self.patience = patience\n",
        "        # best_weights to store the weights at which the minimum loss occurs.\n",
        "        self.best_weights = None\n",
        "\n",
        "    def on_train_begin(self, logs=None):\n",
        "        # The number of epoch it has waited when loss is no longer minimum.\n",
        "        self.wait = 0\n",
        "        # The epoch the training stops at.\n",
        "        self.stopped_epoch = 0\n",
        "        # Initialize the best as infinity.\n",
        "        self.best = np.Inf\n",
        "\n",
        "    def on_epoch_end(self, epoch, logs=None):\n",
        "        current = logs.get(\"loss\")\n",
        "        if np.less(current, self.best):\n",
        "            self.best = current\n",
        "            self.wait = 0\n",
        "            # Record the best weights if current results is better (less).\n",
        "            self.best_weights = self.model.get_weights()\n",
        "        else:\n",
        "            self.wait += 1\n",
        "            if self.wait >= self.patience:\n",
        "                self.stopped_epoch = epoch\n",
        "                self.model.stop_training = True\n",
        "                print(\"Restoring model weights from the end of the best epoch.\")\n",
        "                self.model.set_weights(self.best_weights)\n",
        "\n",
        "    def on_train_end(self, logs=None):\n",
        "        if self.stopped_epoch > 0:\n",
        "            print(\"Epoch %05d: early stopping\" % (self.stopped_epoch + 1))\n",
        "\n",
        "\n",
        "model = get_model()\n",
        "model.fit(\n",
        "    x_train,\n",
        "    y_train,\n",
        "    batch_size=64,\n",
        "    steps_per_epoch=5,\n",
        "    epochs=30,\n",
        "    verbose=0,\n",
        "    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "939ecfbe0383"
      },
      "source": [
        "### 학습 속도 스케줄링\n",
        "\n",
        "이 예제에서는 사용자 정의 콜백을 사용하여 훈련 동안 옵티마이저의 학습 속도를 동적으로 변경하는 방법을 보여줍니다.\n",
        "\n",
        "보다 일반적인 구현에 대해서는 `callbacks.LearningRateScheduler`를 참조하세요."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "71c752b248c0"
      },
      "outputs": [],
      "source": [
        "class CustomLearningRateScheduler(keras.callbacks.Callback):\n",
        "    \"\"\"Learning rate scheduler which sets the learning rate according to schedule.\n",
        "\n",
        "  Arguments:\n",
        "      schedule: a function that takes an epoch index\n",
        "          (integer, indexed from 0) and current learning rate\n",
        "          as inputs and returns a new learning rate as output (float).\n",
        "  \"\"\"\n",
        "\n",
        "    def __init__(self, schedule):\n",
        "        super(CustomLearningRateScheduler, self).__init__()\n",
        "        self.schedule = schedule\n",
        "\n",
        "    def on_epoch_begin(self, epoch, logs=None):\n",
        "        if not hasattr(self.model.optimizer, \"lr\"):\n",
        "            raise ValueError('Optimizer must have a \"lr\" attribute.')\n",
        "        # Get the current learning rate from model's optimizer.\n",
        "        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))\n",
        "        # Call schedule function to get the scheduled learning rate.\n",
        "        scheduled_lr = self.schedule(epoch, lr)\n",
        "        # Set the value back to the optimizer before this epoch starts\n",
        "        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)\n",
        "        print(\"\\nEpoch %05d: Learning rate is %6.4f.\" % (epoch, scheduled_lr))\n",
        "\n",
        "\n",
        "LR_SCHEDULE = [\n",
        "    # (epoch to start, learning rate) tuples\n",
        "    (3, 0.05),\n",
        "    (6, 0.01),\n",
        "    (9, 0.005),\n",
        "    (12, 0.001),\n",
        "]\n",
        "\n",
        "\n",
        "def lr_schedule(epoch, lr):\n",
        "    \"\"\"Helper function to retrieve the scheduled learning rate based on epoch.\"\"\"\n",
        "    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:\n",
        "        return lr\n",
        "    for i in range(len(LR_SCHEDULE)):\n",
        "        if epoch == LR_SCHEDULE[i][0]:\n",
        "            return LR_SCHEDULE[i][1]\n",
        "    return lr\n",
        "\n",
        "\n",
        "model = get_model()\n",
        "model.fit(\n",
        "    x_train,\n",
        "    y_train,\n",
        "    batch_size=64,\n",
        "    steps_per_epoch=5,\n",
        "    epochs=15,\n",
        "    verbose=0,\n",
        "    callbacks=[\n",
        "        LossAndErrorPrintingCallback(),\n",
        "        CustomLearningRateScheduler(lr_schedule),\n",
        "    ],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c9be225b57f1"
      },
      "source": [
        "### 내장 Keras 콜백\n",
        "\n",
        "[API 문서](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/)를 읽고 기존 Keras 콜백을 확인하세요. 애플리케이션에는 CSV에 로깅하기, 모델 저장하기, TensorBoard에서 메트릭 시각화하기 등이 포함됩니다."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "custom_callback.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
